1 Introduction
- Build a model able to filter user comments according to how harmful their language is.
- Preprocess the text by removing the set of tokens that carry no meaningful semantic contribution.
- Transform the text corpus into sequences.
- Build a Deep Learning model with recurrent layers for a multilabel classification task.
At prediction time, the model must return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). A harmless comment is therefore classified as a vector of all zeros [0,0,0,0,0,0], while a harmful comment has at least one 1 among the 6 labels.
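The thresholding step that turns the model's per-label probabilities into such a binary vector can be sketched as follows (a minimal numpy sketch with hypothetical probabilities; the 0.5 cutoff is a placeholder, the actual threshold is tuned later in this post):

```python
import numpy as np

labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
# Hypothetical sigmoid outputs of the model for a single comment
probs = np.array([0.91, 0.12, 0.64, 0.02, 0.55, 0.08])
# Convert per-label probabilities into the binary label vector
binary = (probs >= 0.5).astype(int)
print(binary.tolist())   # [1, 0, 1, 0, 1, 0]
print(not binary.any())  # False: at least one 1, so the comment is not clean
```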
2 Setup
Leveraging Quarto and RStudio, I will set up an R and Python environment.
2.1 Import R libraries
Import the R libraries. They will be used both to render the document and for data analysis, since I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.
2.2 Import Python packages
Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp
from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score
Create a Config class to store all the useful parameters for the model and for the project.
2.3 Class Config
I created a class holding all the basic configuration of the model, to improve readability.
Code
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911  # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571  # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model = self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
            F1Score(name="f1", average="macro")
        ]

    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1",  # "val_recall"
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1",  # "val_recall"
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint

    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
        # instantiate KFold
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        threshold_scores = []
        for threshold in thresholds:
            cv_scores = []
            for train_index, val_index in kf.split(ytrue):
                ytrue_val = ytrue[val_index]
                yproba_val = yproba[val_index]
                ypred_val = (yproba_val >= threshold).astype(int)
                score = metric(ytrue_val, ypred_val, average="macro")
                cv_scores.append(score)
            mean_score = np.mean(cv_scores)
            threshold_scores.append((threshold, mean_score))
        # find the threshold with the highest mean score
        best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
        return best_threshold, best_score

config = Config()
3 Data
The dataset is accessible using tf.keras.utils.get_file, which downloads the file from the url. N.B. For reproducibility purposes, I also downloaded the dataset locally, since there were times when the link was unavailable.
# A tibble: 5 × 8
comment_text toxic severe_toxic obscene threat insult identity_hate
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 "Explanation\nWhy the … 0 0 0 0 0 0
2 "D'aww! He matches thi… 0 0 0 0 0 0
3 "Hey man, I'm really n… 0 0 0 0 0 0
4 "\"\nMore\nI can't mak… 0 0 0 0 0 0
5 "You, sir, are my hero… 0 0 0 0 0 0
# ℹ 1 more variable: sum_injurious <dbl>
Let's create a clean variable for EDA purposes: I want to see visually how many observations are clean versus the other labels.
3.1 EDA
First, a check on the dataset for possible missing values and imbalances.
3.1.1 Frequency
Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels
df_r_grouped <- df_r %>%
select(all_of(new_labels_r)) %>%
pivot_longer(
cols = all_of(new_labels_r),
names_to = "label",
values_to = "value"
) %>%
group_by(label) %>%
summarise(count = sum(value)) %>%
mutate(freq = round(count / sum(count), 4))
df_r_grouped
# A tibble: 7 × 3
label count freq
<chr> <dbl> <dbl>
1 clean 143346 0.803
2 identity_hate 1405 0.0079
3 insult 7877 0.0441
4 obscene 8449 0.0473
5 severe_toxic 1595 0.0089
6 threat 478 0.0027
7 toxic 15294 0.0857
3.1.2 Barchart
Code
library(reticulate)
barchart <- df_r_grouped %>%
ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
geom_col() +
labs(
x = "Labels",
y = "Count"
) +
# sort bars in descending order
scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
scale_fill_brewer(type = "seq", palette = "RdYlBu") +
theme_minimal()
ggplotly(barchart)
It is visible how imbalanced the dataset is. This means it could be useful to compute the class weights and use this argument during training.
It is clear that most of our comments are clean: clean accounts for 80.33% of the label counts, while the six toxic labels together account for only 19.67%.
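The same frequency computation can be done on the Python side with pandas (a toy sketch on a hypothetical 4-row frame, not the real dataset):

```python
import pandas as pd

labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
df = pd.DataFrame({
    "toxic":         [0, 1, 0, 1],
    "severe_toxic":  [0, 0, 0, 1],
    "obscene":       [0, 1, 0, 0],
    "threat":        [0, 0, 0, 0],
    "insult":        [0, 0, 0, 1],
    "identity_hate": [0, 0, 0, 0],
})
# A comment is clean when none of the six labels is set
df["clean"] = (df[labels].sum(axis=1) == 0).astype(int)
counts = df.sum()                       # per-label counts, as in the R pipeline
freq = (counts / counts.sum()).round(4)  # relative frequency of each label
```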
3.2 Sequence length definition
To convert the text into a usable input for a NN, it is necessary to use a TextVectorization layer. See the Section 4 section.
One of its parameters is output_sequence_length: to choose a good value, it is useful to analyze the text lengths. To simulate what the model will do, we remove the punctuation and the newlines from the comments.
3.2.1 Summary
Code
# A tibble: 1 × 6
Min. `1st Qu.` Median Mean `3rd Qu.` Max.
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 4 91 196 378. 419 5000
3.2.2 Boxplot
Code
library(reticulate)
boxplot <- df_r %>%
mutate(
comment_text_clean = comment_text %>%
tolower() %>%
str_remove_all("[[:punct:]]") %>%
str_replace_all("\n", " "),
text_length = comment_text_clean %>% str_count()
) %>%
# pull(text_length) %>%
ggplot(aes(y = text_length)) +
geom_boxplot() +
theme_minimal()
ggplotly(boxplot)
3.2.3 Histogram
Code
library(reticulate)
df_ <- df_r %>%
mutate(
comment_text_clean = comment_text %>%
tolower() %>%
str_remove_all("[[:punct:]]") %>%
str_replace_all("\n", " "),
text_length = comment_text_clean %>% str_count()
)
Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)
histogram <- df_ %>%
ggplot(aes(x = text_length)) +
geom_histogram(bins = 50) +
geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
theme_minimal() +
xlab("Text Length") +
ylab("Frequency") +
xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Considering all the above analysis, I think a good starting value for the output_sequence_length is 911, the upper fence of the boxplot (the dashed red vertical line in the last plot). Doing so, we remove the outliers, which are a small part of our dataset.
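The upper-fence rule used in the R chunk above can be reproduced in Python (a sketch on hypothetical comment lengths, not the real distribution):

```python
import numpy as np

# Hypothetical cleaned-comment lengths in characters
lengths = np.array([4, 91, 150, 196, 300, 419, 1200, 5000])
q1, q3 = np.percentile(lengths, [25, 75])
iqr = q3 - q1
# Tukey's upper fence: values above it are treated as outliers
upper_fence = int(q3 + 1.5 * iqr)
```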
3.3 Dataset
Now we can split the dataset in 3: train, test and validation sets. Since there is no function in sklearn which splits into these 3 sets directly, we can do the following:
- split into a train and a temporary set with a 0.3 split;
- split the temporary set into two equally sized test and validation sets.
Code
x = df[config.features].values
y = df[config.labels].values
xtrain, xtemp, ytrain, ytemp = train_test_split(
x,
y,
test_size=config.temp_split, # .3
random_state=config.random_state
)
xtest, xval, ytest, yval = train_test_split(
xtemp,
ytemp,
test_size=config.test_split, # .5
random_state=config.random_state
)
The datasets are created using the tf.data.Dataset API, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which creates a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.
Code
train_ds = (
tf.data.Dataset
.from_tensor_slices((xtrain, ytrain))
.shuffle(xtrain.shape[0])
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
test_ds = (
tf.data.Dataset
.from_tensor_slices((xtest, ytest))
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
val_ds = (
tf.data.Dataset
.from_tensor_slices((xval, yval))
.batch(config.batch_size)
.prefetch(tf.data.experimental.AUTOTUNE)
)
Code
train_ds cardinality: 3491
val_ds cardinality: 748
test_ds cardinality: 748
Check the first element of the dataset to be sure that the preprocessing is done correctly.
(array([b'or, perhaps can it be marked as a stub so that moderators will know that there is more work to be done on it?',
b'"\n\n Let\'s collaborate \n\nAshish Jee Namaste,\nGood to know about your activities in Wikipedia. As you know we are recently recognized as ""Wikimedians of Nepal"" from the WMF. There are some wikipedia editors editing in Nepali and other local languages. There are little editors from Nepal in touch who are editing in English and other language Wikipedia. If we could be in touch, we could have good collaboration. Regards. \n Thanks Ganesh g, how can i be in that group ? "',
b'"\nI understand you, but you also should understand me. You shouldn\'t revert my work. As you can see, I don\'t disrupt your edits at articles Draza Mihailovic, Josip Broz Tito, or at Template:Yugoslav Axis collaborationism. I realy respect your effort and work there, and I don\'t revert it. I have no desire to get blocked, or to drag you into another blockade. I want to resolve all this disputes with you. As you can see I already changed Template:Politics of Yugoslavia to a neutral, non-image version. These are my proposals to you:\n\n1) If we put in place your version of the list of Yugoslav Prime Ministers, it will remove Drago Marusic as Prime Minister in 1945, between Ivan Subasic and Tito. If you agree to left him on the list, this dispute is over. If we agree about Marusic, we can also agree about the issue ""Kingdom or DF Yugoslavia"".\n\n2) Federal Republic of Yugoslavia was a complete different state than Serbia and Montenegro. It wasn\'t like Kingdom of Serbs, Croats and Slovenes and Kingdom of Yugoslavia, or Federal People\'s Republic of Yugoslavia and Socialist Federal Republic of Yugoslavia. In 2003, after the adoption of the Constitutional Charter, whole system of government was changed. Even the position of the Prime Minister was abolished, and merged with the position of the President. Those two countries realy need a separate articles. Also, their armies need a separate articles.\n\n3) Articles President of the Federal Republic of Yugoslavia and President of Serbia and Montenegro should remain separate. The last President of FRY was Vojislav Kostunica, from 2000 until 2003 when he was replaced by Svetozar Marovic of SCG (2003-2006). Those two offices was very different (President of FRY wasn\'t both head of state and head of government, like President of SCG). \n\n4) Prime Minister of the Federal Republic of Yugoslavia wasn\'t ""Prime Minister of Serbia and Montenegro"", but of FRY. 
That article should retain its current name.\n\n5) As you can see, yesterday I changed Template:Politics of Yugoslavia to a neutral, non-image version. I hope that you agree with this version.\n\nP.S. \n\nI\'m sorry if you think that people from the Balkans are uncivilized. It\'s not true, believe me. We aren\'t better or worse than other peoples. "',
b"I've corrected the climate records for the chart. The source led to Dell Regional High School in New Jersey, not anywhere in New York City. This was probably done so that the introduction of NYC's climate could say it lies entirely in the humid subtropical zone, with the coldest month averaging 0\xc2\xb0 C. But the weather.com record for New York City shows the coldest month with an average high of 36 and a low of 23, averaging out to -1.3\xc2\xb0 C, placing it in the transition zone between humid subtropical and humid continental. I will make that change as well.",
b"Template:Infobox very fit for this section.\n My be\xe2\x80\xa6 As min 3 pic is normal, although 5 pic is normal too.\n I think this interview need.\n Why? It is not necessary.\n It's right.\n Meaning section explained base principes of construction of video.\n \xe2\x80\x9cAwards\xe2\x80\x9d in music video section and \xe2\x80\x9cAwards\xe2\x80\x9d in article is not the same thing. \xe2\x80\x9cAwards\xe2\x80\x9d in article is ALL awards for \xe2\x80\x9cUmbrella\xe2\x80\x9d. \xe2\x80\x9cAwards\xe2\x80\x9d in section is awards ONLY for music video.",
b'"\n\n Enjoy! \n\n \n\n has smiled at you! Smiles promote WikiLove and hopefully this one has made your day better. Spread the WikiLove by smiling to someone else, whether it be someone you have had disagreements with in the past or a good friend. Happy editing! Smile at others by adding to their talk page with a friendly message. ."',
b'direct conflict of interest, i.e., where',
b'"\nOk, then the section should be condensed to discuss the aspects in particular where D contributes to generic programming. If Dmeranda is correct in that ""static if"" is a completely other take on generics then that point need to be discussed and exemplified, and I\'m not certain if the factorial example is an adequate illustration of this since I failed to get that point. Also, the alias parameters seems to be of relevance and a discussion about these should be added. Static derival of return types risks venturing into pure language comparisons since C++0x will be equipped with a of_type (forgot what it is to be named) static operator and other languages have this facility too. Thus in short, the section should be condensed, sharpened to drive the relevant points home, and a better example should be found (invented?). I am doing a bold edit to test your reaction: since the ""D templates"" section is in all a commentary on the ""C++ templates"" section, I\'m moving it to that section as a corollary. It makes more sense like tat the way it is written atm. Feel free to revert if you disapprove of this. "',
b'"\n\n This is weird... \n\nThanks for improving the article Ayaka Asai, which I created (together with Tomoyo Kurosawa and Moe Toyota) a few months ago. I noticed a very odd error though: her agency profile says that she had a role in the 2013 Pok\xc3\xa9mon theatrical short Pikachu and Its Eevee Friends as ""Glacia"" (Glaceon). As a Pok\xc3\xa9mon fan myself, I immediately noticed some discrepancies: this website and the movie credits give the voice actress for the said role as Akeno Watanabe (itself an article needing immediate attention). I\'m confused (in fact, the said role was included in my edit that created the article, but as you can see with this edit, I removed it for the reasons I mentioned above). could the agency have made a mistake? tccsdnew "',
b"If it's not them suffering; why mention it?",
b'I think that message was to SHanes, am not expecting responses from others.',
b"I'd say that we should keep this from enlarging to a debate on this discussion page since Wikipedia isn't a forum. But on the other hand, I can see how Aerith being a WiR is certainly arguable, but nevertheless, it isn't really all that important to the article so I don't think there's any reason to include this part in. \xe2\x80\x94",
b'"The Epoch Times is a shameless propaganda news service, who claims to counter ""censorship"" by spewing conspiracy theories and counter-propaganda about the CCP. It often publishes articles relating to ""natural phenomenons"" in China, ""predicting"" the fall of the CCP, as well as Jiang Zemin\'s supposed death in 2003. It\'s campaigns to smear the CCP is not followed at all in the media, such as it\'s campaigns to ""sue"" Jiang Zemin. It claims to support democracy, yet intolerant of criticism, censoring posts in their online forums and blanking pages on Wikipedia . The so-called FLG death camp is in fact a Malaysian joint venture sponsored by the Malaysian government . I\'m not going to ""save"" the world with a group of cultists. \n\n"',
b'"\n Yes, Apteva ... I do wish you yourself would learn to be civil and collegial. Your snide, sarcastic and snippy comments - plus the badgering of anyone who dares speak against you - show that civil and collegial are either not in your personal dictionary, or that you have long forgotten their meaning (\xe2\x9c\x89\xe2\x86\x92\xe2\x86\x90\xe2\x9c\x8e) "',
b"Participant alert regarding Wikiproject on Advertising\nThe Wikiproject No Ads, created as a backlash against the Answers.com deal, has served an important function in providing a space for users to express their disagreement with the Foundation proposal. While the current controversies about userboxes raise questions about political and social advocacy on Wikipedia, there should be greater flexibility regarding advocacy about Wikipedia in the Wikipedia namespace. Reported and linked by Slashdot and other press sources as a unique and spontaneous occurence in Wikipedia history, it has apparently had some impact as, despite being scheduled to begin in January, not a peep has been heard about the trial and proposed sponsored link since the deal's controversial announcement months ago. Currently, however, there is an attempt to delete the project or move it off Wikipedia altogether. Since the Foundation has provided no additional information and has not attempted to answer the specific questions that participants in the project raised, it is unclear if the Answers.com deal has been abandoned or simply delayed. Until the situation becomes more clear, I believe the group should still have a place in the Wikipedia namespace. Sincerely,",
b'Check out my youtube account and leave a comment!! http://www.youtube.com/profile?user=NaturalNeil',
b'"\nPer WP:NLIST: ""Furthermore, every entry in any such list requires a reliable source attesting to the fact that the named person is a member of the listed group."" WP:BURDEN say when you add the info, you provide the source. The Dissident Aggressor "',
b'"\nPersonally, I have never seen it before, ever. But if it is common practice and I am in error, then please accept my apologies Chris. When I saw it occur, it was canvassing to me. (talk \xe2\x99\xa6 contribs) @ \'\' "',
b'"\n\n assistance \n\nI seek administrator assistance. User:Emerson7 monitors my edits and reverts them whevener he feels like it. In the articles ""List of billionaires (2006)"" and ""List of billionaires (2007)"" he reverts my addition of the Lebanese flag next to Carlos Slim Helu\'s name, and does it sometimes with using an IP instead of his username. Get him off my back, if you will. As much as it may bother ""Emerson"", Carlos Slim Helu\'s parents are both Lebanese, his holds the Lebanese citizenship and visits his country very often. "',
b'fuk you fucked up sinhalese motherfucking assholes wanna suck my cock bitch ass . jeyasinghe',
b'Sounds fine to me. Thank you for your work. The comments to The Master on his talk page were directed to his specific edits. Thank you, again.',
b'Thank you for experimenting with the page Chalastra on Wikipedia. Your test worked, and has been reverted or removed. Please use the sandbox for any other tests you want to do. Take a look at the welcome page if you would like to learn more about contributing to our encyclopedia. A link to the edit I have reverted can be found here: link. If you believe this edit should not have been reverted, please contact me.',
b'"Did you also notice that the information transmitted in ""pevc.dowjones.com"" was taken directly from a chess.com blog post? You are not very good at this, are you? "',
b'Please stop vandalizing\n\nPlease do not add nonsense to Wikipedia. It is considered vandalism. If you would like to experiment, use the sandbox. Thank you.',
b"Thanks. Yes, a bad production might be particularly notable! I guess this was the filmed version of that show. Kingdom looks quite like him there, doesn't he.",
b'"\n\nSupport - ""chairman"" is more common and unforced, and carries no implications of masculine sex, only gender (and anyway English doesn\'t really have grammatical gender). See Merriam-Webster\'s (""the presiding officer of a meeting, organization, committee, or event; the administrative officer of a department of instruction (as in a college)"") and Oxford (""The occupier of a chair of authority; spec. the person who is chosen to preside over a meeting, to conduct its proceedings, and who occupies the chair or seat provided for this function.""). Talk "',
b"Have I made myself clear yet? It should be pretty fucking obvious by now that I don't give a shit about this account, you senseless cunt.",
b"You used false and misleading edit summaries more than once which shows a WP:BATTLEGROUND mentality on your part. This summary claims that you reported vandalism so it is immaterial whether you had or not. In this one you claim that there was consensus on the talk page for your edit when there clearly wasn't AND you claim that you notified an admin which is also a false statement. These are not the actions of someone acting in good faith and they clearly justify the notice on your talk page. |Talk",
b'Very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very greatly appreciated. His Imposingness, the Grand Moff (talk)',
b'GET OFF OF WIKIPEDIA RETARD!',
b'"\n\n Say, Tim, could you take a look at ... \n\n... the Excommunication article. particularly the amish/mennonite sections. this same anonymous IP editor has been trying to do the same to the Mennonite article. he/she is obsessed and persistent. i certainly think there should be some discussion of the pathological use of ""The Ban"" you might find in really conservative menno and amish congregations, but this stuff is far from representative of the denomination. "',
b'While things are protected... \n\nI found it refreshing to edit on the libertarian wiki. I transplanted the last good version of the anarchism article there. You know, with the individualist/collectivist structure that we worked on after the last unprotect. I merged your Individual anarchism and American individual anarchism articles - that was obviously an edit war fork. Anyway, check these out:\n\nAnarchism\n\nIndividualist anarchism'],
dtype=object), array([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]))
We also check the shapes. We expect features of shape (batch,) and targets of shape (batch, number of labels).
Code
text train shape: (32,)
text train type: object
label train shape: (32, 6)
label train type: int64
4 Preprocessing
Of course, preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:
1. Standardize each example (usually lowercasing + punctuation stripping).
2. Split each example into substrings (usually words).
3. Recombine substrings into tokens (usually ngrams).
4. Index tokens (associate a unique int value with each token).
5. Transform each example using this index, either into a vector of ints or a dense float vector.
For more reference, see the documentation at the following link.
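The five steps can be sketched in plain Python (a toy re-implementation for intuition only, not the Keras code path; the ngram recombination of step 3 is skipped, as it is here too, since the layer splits on whitespace):

```python
import re
from collections import Counter

def toy_vectorize(texts, max_tokens=10, seq_len=6):
    # 1-2. standardize (lowercase + strip punctuation) and split on whitespace
    tokenized = [re.sub(r"[^\w\s]", "", t.lower()).split() for t in texts]
    # 4. index tokens by frequency; 0 is reserved for padding, 1 for out-of-vocabulary
    counts = Counter(w for toks in tokenized for w in toks)
    index = {w: i + 2 for i, (w, _) in enumerate(counts.most_common(max_tokens - 2))}
    # 5. map each example to ints, padding/truncating to a fixed length
    return [([index.get(w, 1) for w in toks] + [0] * seq_len)[:seq_len] for toks in tokenized]
```

For example, `toy_vectorize(["Hello world!", "hello there"])` returns `[[2, 3, 0, 0, 0, 0], [2, 4, 0, 0, 0, 0]]`: both comments share the index for "hello", and the shorter sequences are zero-padded.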
Code
text_vectorization = TextVectorization(
max_tokens=config.max_tokens,
standardize="lower_and_strip_punctuation",
split="whitespace",
output_mode="int",
output_sequence_length=config.output_sequence_length,
pad_to_max_tokens=True
)
# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)
This layer is set to:
- max_tokens: 20000, a common choice for text classification. It is the maximum size of the vocabulary for this layer.
- output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
- output_mode: "int", which outputs one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
- standardize: "lower_and_strip_punctuation".
- split: on whitespace.
To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization function, it is possible to map it to the features of each dataset.
Code
processed_train_ds = train_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
lambda x, y: (text_vectorization(x), y),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
5 Model
5.1 Definition
Define the model using the Functional API.
Code
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # Global average pooling
    x = GlobalAveragePooling1D()(x)
    # Add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()
5.2 Callbacks
Finally, the model has been trained using 2 callbacks:
- Early Stopping, to avoid wasting Kaggle GPU time.
- Model Checkpoint, to retrieve the best model found during training.
5.3 Final preparation before fit
Considering the dataset is imbalanced, to improve performance we need to calculate the class weights. These will be passed to the model during training.
class_weight
toxic 0.095900590
severe_toxic 0.009928468
obscene 0.052757858
threat 0.003061800
insult 0.049132042
identity_hate 0.008710911
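One way to obtain weights like those above is to use each label's relative frequency (a sketch using the per-label counts from the EDA on the full dataset; the exact values in the table differ slightly, as they appear to be computed on the training split):

```python
# Per-label positive counts from the EDA (full dataset, 159,571 comments)
counts = {
    "toxic": 15294, "severe_toxic": 1595, "obscene": 8449,
    "threat": 478, "insult": 7877, "identity_hate": 1405,
}
total = 159571
# Keras expects class indices as keys, in the same order as config.labels
class_weight = {i: c / total for i, c in enumerate(counts.values())}
```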
It is also useful to define the steps per epoch for the train and validation datasets. This step is required to make sure the entire dataset is consumed during the fit, an issue I ran into.
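Since .repeat() makes the datasets effectively infinite, the steps can be derived from the sample counts stored in Config (a sketch using those values):

```python
import math

# Values from Config; batch_size = 32
train_samples, val_samples, batch_size = 111699, 23936, 32
steps_per_epoch = math.ceil(train_samples / batch_size)  # matches train_ds cardinality (3491)
validation_steps = math.ceil(val_samples / batch_size)   # matches val_ds cardinality (748)
```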
5.4 Fit
The fit has been done on Kaggle to leverage the GPU. Some considerations about the fit:
- .repeat() ensures the model sees the whole dataset.
- epochs is set to 100.
- validation_data has the same repeat.
- callbacks are the ones defined before.
- class_weight ensures the model is trained using the frequency of each class, because our dataset is imbalanced.
- steps_per_epoch and validation_steps depend on the use of repeat.
Now we can import the model and the history trained on Kaggle.
5.5 Evaluate
Code
# A tibble: 5 × 2
metric value
<chr> <dbl>
1 loss 0.0542
2 precision 0.789
3 recall 0.671
4 auc 0.957
5 f1_score 0.0293
5.6 Predict
For the prediction, the model does not need the dataset to repeat, because it has already been trained on all of the training data; now it just has to consume the new data to make predictions.
5.7 Confusion Matrix
The best way to assess the performance of a multi-label classifier is a confusion matrix. Sklearn has a specific function to create a multi-label confusion matrix, which handles the fact that there can be multiple labels for one prediction.
5.7.1 Grid Search Cross Validation for best threshold
Grid Search CV is a technique for fine-tuning the hyperparameters of a ML model. It systematically searches through a set of hyperparameter values to find the combination that leads to the best model performance. In this case, I am using KFold Cross Validation, a resampling technique which splits the data into k consecutive folds. Each fold is used once as the validation set while the k - 1 remaining folds form the training set. See the documentation for more information.
The model is trained to optimize recall. The decision was made because the cost of missing a True Positive is greater than that of a False Positive: missing an injurious comment is worse than classifying a clean one as injurious.
5.7.2 Confidence threshold and Precision-Recall trade off
While the KFold GSCV technique is useful to test multiple hyperparameters, it is important to understand the problem we are facing. A multi-label deep learning classifier outputs a vector of per-class probabilities. These need to be converted to a binary vector using a confidence threshold.
- The higher the threshold, the less classes the model predicts, increasing model confidence [higher Precision] and increasing missed classes [lower Recall].
- The lower the threshold, the more classes the model predicts, decreasing model confidence [lower Precision] and decreasing missed classes [higher Recall].
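These opposite movements can be checked on toy data (a minimal sklearn sketch with hypothetical probabilities, not the model's outputs):

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

ytrue = np.array([1, 1, 1, 0, 0, 1, 0, 1])
yproba = np.array([0.9, 0.6, 0.4, 0.35, 0.2, 0.8, 0.55, 0.1])

precisions, recalls = [], []
for t in (0.3, 0.5, 0.7):
    ypred = (yproba >= t).astype(int)
    precisions.append(precision_score(ytrue, ypred))
    recalls.append(recall_score(ytrue, ypred))
# Raising the threshold increases precision and decreases recall
```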
Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnostics: it is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to miss the disease in a patient who has it [False Negative].
I decided to train the model on the F1 score to have a model balanced in both precision and recall, and to leave it to the threshold selection to boost recall.
Moreover, the model has been trained on the macro average F1 score, which is a single performance indicator obtained as the unweighted mean of the per-class F1 scores.
\[ F1\ macro\ avg = \frac{\sum_{i=1}^{n} F1_i}{n} \]
It is useful with imbalanced classes, because it weights each class equally: it is not influenced by the number of samples in each class. This is set both in config.metrics and in find_optimal_threshold_cv.
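The definition can be verified with sklearn on a toy multi-label example: the macro F1 equals the unweighted mean of the per-class F1 scores (hypothetical 3-label data for illustration):

```python
import numpy as np
from sklearn.metrics import f1_score

ytrue = np.array([[1, 0, 0], [1, 1, 0], [0, 0, 1], [0, 1, 0]])
ypred = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1]])

per_class = f1_score(ytrue, ypred, average=None)  # one F1 per label
macro = f1_score(ytrue, ypred, average="macro")   # their unweighted mean
```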
5.7.2.1 f1_score
Code
Optimal threshold: 0.15000000000000002
Best score: 0.4788653077945807
Optimal threshold f1 score: 0.15. Best score: 0.4788653.
5.7.2.2 recall_score
Code
Optimal threshold recall: 0.05. Best score: 0.8095814.
5.7.2.3 roc_auc_score
Code
Optimal threshold: 0.05
Best score: 0.8809499649742268
Optimal threshold roc: 0.05. Best score: 0.88095.
5.7.3 Confusion Matrix Plot
Code
# convert probability predictions to predictions
ypred = predictions >= optimal_threshold_recall # .05
ypred = ypred.astype(int)
# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(ax=axes[i], colorbar=False)
axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
5.8 Classification Report
Code
# A tibble: 10 × 5
metrics precision recall `f1-score` support
<chr> <dbl> <dbl> <dbl> <dbl>
1 toxic 0.552 0.890 0.682 2262
2 severe_toxic 0.236 0.917 0.375 240
3 obscene 0.550 0.936 0.692 1263
4 threat 0.0366 0.493 0.0681 69
5 insult 0.471 0.915 0.622 1170
6 identity_hate 0.116 0.720 0.200 207
7 micro avg 0.416 0.896 0.569 5211
8 macro avg 0.327 0.812 0.440 5211
9 weighted avg 0.495 0.896 0.629 5211
10 samples avg 0.0502 0.0848 0.0597 5211
6 Conclusions
The BiLSTM model, optimized for high recall, performs well enough to make predictions for each label. Considering the low support of the threat label, the performance is not bad: see Table 2 and Figure 1; the threat label accounts for only 0.27% of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment injurious.
Possible improvements could be to increase the number of observations, especially for the threat label. In general there are too many clean comments. This could be addressed by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.